import os
import matplotlib.pyplot as plt
import numpy as np


def dict2value(dict_str):
    """Return the second whitespace-separated token of *dict_str* as a float.

    Used on single comma-separated log fields such as ``"name: 0.1234"``,
    where the numeric value follows the key token.

    Raises:
        IndexError: if *dict_str* has fewer than two tokens.
        ValueError: if the second token is not a valid float literal.
    """
    # maxsplit=2 still isolates the second token exactly like a full split().
    value_token = dict_str.split(maxsplit=2)[1]
    return float(value_token)


def read_log(file_path, is_skip=False):
    """Collect loss / ratio statistics from a text training log.

    Two kinds of lines are recognised (after stripping whitespace):

    * lines containing ``'Additional logging info:'`` -- the marker is turned
      into a comma, the line is split on commas, and field 2 is parsed with
      ``dict2value``, rounded to 4 decimals and appended to ``pseudo_acces``;
    * lines containing ``'iteration USE_EMA: True,'`` -- the line is split on
      commas and fields 2, 3, 5 and 7 are parsed with ``dict2value`` into
      ``sup_losses``, ``mixed_losses``, ``unsup_losses`` and ``ulb_ratios``.

    Args:
        file_path: path of the log file to read.
        is_skip: when True, only every other matching line of each kind is
            kept (the 1st, 3rd, 5th, ...). The two kinds keep independent
            counters.

    Returns:
        Tuple ``(sup_losses, unsup_losses, mixed_losses, pseudo_acces,
        ulb_ratios)`` of float lists, in file order.

    Raises:
        OSError: if the file cannot be opened.
        IndexError / ValueError: if a matching line has too few fields or a
            non-numeric value (propagated from the parsing).
    """
    sup_losses, unsup_losses, mixed_losses, pseudo_acces, ulb_ratios = [], [], [], [], []
    # Parity flags used only when is_skip is set: True means the previous
    # matching line of that kind was kept, so the next one is dropped.
    add_skip, ema_skip = False, False

    # The context manager guarantees the handle is closed (the original
    # leaked it); iterating the file object avoids loading it all at once.
    with open(file_path, 'r') as log_file:
        for raw_line in log_file:
            line = raw_line.strip()
            if 'Additional logging info:' in line:
                if is_skip:
                    add_skip = not add_skip
                    if not add_skip:
                        continue
                # Replace the marker with a separator so the info fields line
                # up with the preceding comma-separated fields.
                blocks = line.replace(' Additional logging info:', ',').split(',')
                pseudo_acc = dict2value(blocks[2])
                pseudo_acces.append(round(pseudo_acc, 4))
            elif 'iteration USE_EMA: True,' in line:
                if is_skip:
                    ema_skip = not ema_skip
                    if not ema_skip:
                        continue
                blocks = line.split(',')
                # Parse every field before appending so a malformed line does
                # not leave the four lists with different lengths.
                sup_loss = dict2value(blocks[2])
                mixed_loss = dict2value(blocks[3])
                unsup_loss = dict2value(blocks[5])
                ratio = dict2value(blocks[7])
                sup_losses.append(sup_loss)
                mixed_losses.append(mixed_loss)
                unsup_losses.append(unsup_loss)
                ulb_ratios.append(ratio)
    return sup_losses, unsup_losses, mixed_losses, pseudo_acces, ulb_ratios


def double_y_zhexian(left_y1, left_y2, right_y1, right_y2, save_name, ylim=(0.0, 2.0)):
    """Draw four curves on one set of axes, save the figure, then show it.

    ``zhexian`` is pinyin for "line chart". All four series share a single
    y-axis; the left/right parameter names are kept for caller compatibility.

    Args:
        left_y1, left_y2, right_y1, right_y2: 1-D sequences of y values. Each
            is plotted against its own index ``0..len-1``, so they may have
            different lengths.
        save_name: output path passed to ``Figure.savefig``.
        ylim: ``(bottom, top)`` y-axis limits; defaults to the previously
            hard-coded ``(0.0, 2.0)``.
    """
    fig, ax1 = plt.subplots()

    series_and_colors = (
        (left_y1, 'dodgerblue'),
        (left_y2, 'royalblue'),
        (right_y1, 'darkgreen'),
        (right_y2, 'tomato'),
    )
    for series, color in series_and_colors:
        # Per-series x range: sharing one x built from left_y1 raised a
        # shape-mismatch error whenever the logs had different lengths.
        ax1.plot(np.arange(len(series)), series, color=color, marker=None)
    ax1.set_ylim(*ylim)

    # Save BEFORE showing: with GUI backends show() blocks until the window
    # is closed, after which the figure is gone and savefig wrote a blank image.
    fig.savefig(save_name)
    plt.show()


# Ablation comparison between two runs: the first log is parsed as-is, the
# second with is_skip=True (only every other matching line of each kind kept).
file1_path = './main_jhc_220_2_mixed_16384_contrast_improved.txt'
file2_path = './main_jhc_16384_wo_mix.txt'

(sup_losses_1, unsup_losses_1, mixed_losses_1,
 pseudo_acces_1, ulb_ratios_1) = read_log(file1_path)
(sup_losses_2, unsup_losses_2, mixed_losses_2,
 pseudo_acces_2, ulb_ratios_2) = read_log(file2_path, is_skip=True)

# Alternative plot of the pseudo_acces_* / ulb_ratios_* series instead:
#   double_y_zhexian(pseudo_acces_1, ulb_ratios_1,
#                    pseudo_acces_2, ulb_ratios_2, 'ablation_mix.png')
double_y_zhexian(
    left_y1=sup_losses_1,
    left_y2=sup_losses_2,
    right_y1=unsup_losses_1,
    right_y2=unsup_losses_2,
    save_name='ablation_mix_loss.png',
)
print(111)
